import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import math
from matplotlib.ticker import MaxNLocator

# Off-screen (non-interactive) backend: the script writes the figure to a file
# with savefig rather than displaying it.
matplotlib.use('Agg')
sns.set_style("whitegrid")
plt.rcParams.update({"font.family": 'DejaVu Sans'})


# ---------------------------------------------------------------------------
# Model 1 results (first panel, titled 'QWEN 2.5 72B' via model_names below).
# These hard-coded means drive the ConGrs/ASC trade-off curves, the x-axis
# limit and the threshold labels. The markers come from the per-model CSV
# when it exists and fall back to these values when it does not.
# ---------------------------------------------------------------------------

# ConGrs swept over its tau threshold (one row per tau).
model1_congrs_data = {
    'method': ['ConGrs'] * 5,
    'baseline': ['ConGrs (tau=0.1)', 'ConGrs (tau=0.3)', 'ConGrs (tau=0.5)', 'ConGrs (tau=0.7)', 'ConGrs (tau=0.9)'],
    'threshold': [0.1, 0.3, 0.5, 0.7, 0.9],
    'mean_factscore': [0.66198, 0.73684, 0.82704, 0.88112, 0.90256],
    'mean_num_facts': [26.408, 19.202, 14.766, 10.536, 7.18999],
}


# ASC swept over its integer theta threshold (one row per theta).
model1_asc_data = {
    'method': ['ASC'] * 5,
    'baseline': ['ASC (theta = 1)', 'ASC (theta = 2)', 'ASC (theta = 3)', 'ASC (theta = 4)', 'ASC (theta = 5)'],
    'threshold': [1, 2, 3, 4, 5],
    'mean_factscore': [0.65206, 0.69866, 0.76696, 0.82796, 0.87442],
    'mean_num_facts': [27.38999, 18.184, 11.47, 7.418, 4.854]
}


# Single-operating-point baselines; 'threshold' is None because they have no sweep.
# NOTE(review): these values are rounded to two decimals, unlike the model 2/3 tables.
model1_other_methods_data = {
    'method': ['Greedy', 'Shortest', 'LM consensus', 'MBR', 'Mean of m', 'QwQ 32B'],
    'baseline': ['Greedy', 'Shortest', 'LM consensus', 'MBR', 'Mean of m', 'QwQ 32B'],
    'threshold': [None] * 6,
    'mean_factscore': [0.69, 0.68, 0.69, 0.70, 0.69, 0.56],
    'mean_num_facts': [18.03, 17.70, 20.59, 18.99, 18.53, 18.87]
}



# ---------------------------------------------------------------------------
# Model 2 results (second panel, titled 'LLAMA 3.3 70B' via model_names below).
# Same column layout as the model 1 tables.
# ---------------------------------------------------------------------------

# ConGrs swept over its tau threshold.
model2_congrs_data = {
    'method': ['ConGrs'] * 5,
    'baseline': [
        'ConGrs (tau=0.1)', 'ConGrs (tau=0.3)', 'ConGrs (tau=0.5)', 'ConGrs (tau=0.7)', 'ConGrs (tau=0.9)'
    ],
    'threshold': [0.1, 0.3, 0.5, 0.7, 0.9],
    'mean_factscore': [0.7634, 0.9200, 0.9340, 0.9334, 0.9411],
    'mean_num_facts': [15.1640, 10.2580, 7.9040, 5.1180, 3.1680],  
}


# ASC swept over its integer theta threshold.
model2_asc_data = {
    'method': ['ASC'] * 5,
    'baseline': [
        'ASC (theta = 1)', 'ASC (theta = 2)', 'ASC (theta = 3)', 'ASC (theta = 4)', 'ASC (theta = 5)'
    ],
    'threshold': [1, 2, 3, 4, 5],
    'mean_factscore': [0.7021, 0.8355, 0.9007, 0.9308, 0.9531],
    'mean_num_facts': [16.4060, 8.8020, 5.3980, 3.3620, 1.8760],  
}


# Single-operating-point baselines; 'threshold' is None because they have no sweep.
model2_other_methods_data = {
    'method': [
        'Greedy', 'Shortest', 'LM consensus', 'MBR', 'Mean of m', 'QwQ 32B'
    ],
    'baseline': [
        'Greedy', 'Shortest', 'LM consensus', 'MBR', 'Mean of m', 'QwQ 32B'
    ],
    'threshold': [None] * 6,
    'mean_factscore': [0.7877, 0.8508, 0.8065, 0.8162, 0.7885, 0.5704],
    'mean_num_facts': [8.4720, 8.1900, 12.3340, 9.9020, 9.6632, 19.4240],  
}



# ---------------------------------------------------------------------------
# Model 3 results (third panel, titled 'OLMO 2 32B' via model_names below).
# Same column layout as the model 1 tables.
# ---------------------------------------------------------------------------

# ConGrs swept over its tau threshold.
model3_congrs_data = {
    'method': ['ConGrs'] * 5,
    'baseline': ['ConGrs (tau=0.1)', 'ConGrs (tau=0.3)', 'ConGrs (tau=0.5)', 'ConGrs (tau=0.7)', 'ConGrs (tau=0.9)'],
    'threshold': [0.1, 0.3, 0.5, 0.7, 0.9],
    'mean_factscore': [0.7104, 0.8294, 0.8853, 0.9303, 0.9607],
    'mean_num_facts': [36.186, 24.72, 15.504, 9.232, 4.708],
}

# ASC swept over its integer theta threshold.
model3_asc_data = {
    'method': ['ASC'] * 5,
    'baseline': ['ASC (theta = 1)', 'ASC (theta = 2)', 'ASC (theta = 3)', 'ASC (theta = 4)', 'ASC (theta = 5)'],
    'threshold': [1, 2, 3, 4, 5],
    'mean_factscore': [0.6978, 0.7716, 0.8492, 0.9138, 0.9646],
    'mean_num_facts': [38.936, 17.854, 7.808, 4.034, 1.586]
}

# Single-operating-point baselines; 'threshold' is None because they have no sweep.
# NOTE(review): the QwQ 32B entry (0.5704, 19.424) is identical to model 2's QwQ
# entry. Confirm this is intended (QwQ as a standalone model) and not a copy-paste.
model3_other_methods_data = {
    'method': ['Greedy', 'Shortest', 'LM consensus', 'MBR', 'Mean of m', 'QwQ 32B'],
    'baseline':  ['Greedy', 'Shortest', 'LM consensus', 'MBR', 'Mean of m', 'QwQ 32B'],
    'threshold': [None] * 6,
    'mean_factscore': [0.7493, 0.7768, 0.7580, 0.7586, 0.7453, 0.5704],
    'mean_num_facts': [24.29, 24.564, 25.224, 25.338, 25.3344, 19.424]
}


# A single row of three panels, one per model (see model_names further down).
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(43.2, 10.8))


# Qualitative ColorBrewer palette (https://colorbrewer2.org); these hex values
# come from the "Set3" scheme. Keys are the display names used for plotting.
colors = {
    'ConGrs': '#bc80bd',        # Purple
    'ASC': '#8dd3c7',           # Teal
    'Greedy': '#bebada',        # Lavender
    'Shortest': '#fb8072',      # Salmon red
    'LM consensus': '#fccde5',  # Light pink
    'MBR': '#fdb462',           # Light orange
    'Mean of m': '#b3de69',     # Light green
    'QwQ 32B': '#80b1d3',       # Light blue
}


# Maps the 'Baselines' labels found in the CSV files to the display names
# used as keys of `colors` and `marker_map`. Every threshold variant of a
# swept method collapses to the same display name.
internal_name_to_display_name = {
    **{f'CONGRS ({tau})': 'ConGrs' for tau in ('0.1', '0.3', '0.5', '0.7', '0.9')},
    'Temp 0': 'Greedy',
    'Short Resp': 'Shortest',
    'LLM Cons w Abs': 'LM consensus',
    'MBR': 'MBR',
    'Mean of N': 'Mean of m',
    'Qwen QWQ 32B': 'QwQ 32B',
    **{f'ASC (Threshold = {theta})': 'ASC' for theta in range(1, 6)},
}


# Drawing order for CSV markers (higher zorder is drawn on top):
# ConGrs points above the single-point baselines, which sit above ASC.
internal_name_to_zorder = {
    **{f'CONGRS ({tau})': 6 for tau in ('0.1', '0.3', '0.5', '0.7', '0.9')},
    **dict.fromkeys(
        ('Temp 0', 'Short Resp', 'LLM Cons w Abs', 'MBR', 'Mean of N', 'Qwen QWQ 32B'), 5
    ),
    **{f'ASC (Threshold = {theta})': 4 for theta in range(1, 6)},
}


# Matplotlib marker glyph for each display name.
marker_map = {
    'Greedy': '^',        # triangle up
    'Shortest': 'v',      # triangle down
    'LM consensus': 'D',  # diamond
    'MBR': 'p',           # pentagon
    'Mean of m': 'h',     # hexagon
    'QwQ 32B': 'X',       # filled X
    'ASC': 's',           # square
    'ConGrs': 'o',        # circle
}


def set_xaxis_limits_from_max_value(ax, df_congrs, df_asc, df_others):
    """Set ``ax``'s x-range to [0, ceil(largest 'mean_num_facts')].

    The upper bound is the ceiling of the largest ``mean_num_facts`` value
    across the three given DataFrames.
    """
    largest = max(frame['mean_num_facts'].max() for frame in (df_congrs, df_asc, df_others))
    ax.set_xlim(0, math.ceil(largest))


def _asc_label_layout(threshold):
    """Return ``(xytext, ha)`` positioning an ASC theta label relative to its point."""
    # The extreme thresholds are centred below their point; the middle ones
    # go below-left and are right-aligned.
    if threshold == 1:
        return (0, -20), 'center'
    if threshold == 5:
        return (-8, -30), 'center'
    return (-15, -12), 'right'


def _annotate_thresholds(ax, congrs_df, asc_df):
    """Label each ConGrs point with its tau and each ASC point with its theta."""
    label_style = dict(fontsize=24, fontweight='bold', alpha=0.9)

    for _, row in congrs_df.iterrows():
        ax.annotate(rf'$\tau={row["threshold"]}$',
                    (row['mean_num_facts'], row['mean_factscore']),
                    xytext=(14, 14), textcoords='offset points',
                    bbox=dict(boxstyle="round,pad=0.1",
                              fc=colors['ConGrs'], lw=0, alpha=0.15),
                    **label_style)

    for _, row in asc_df.iterrows():
        xytext, ha = _asc_label_layout(row['threshold'])
        ax.annotate(rf'$\Theta={int(row["threshold"])}$',
                    (row['mean_num_facts'], row['mean_factscore']),
                    xytext=xytext, textcoords='offset points',
                    ha=ha, va='top',
                    bbox=dict(boxstyle="round,pad=0.1",
                              fc=colors['ASC'], lw=0, alpha=0.15),
                    **label_style)


def _plot_csv_points(ax, model_name, errorbar_data):
    """Plot a mean marker with +/- one-std error bars for every known method in the CSV.

    Methods are visited in the key order of ``internal_name_to_display_name``.
    Methods absent from the CSV are skipped.
    """
    for internal_name, display_name in internal_name_to_display_name.items():
        rows = errorbar_data.loc[errorbar_data['Baselines'] == internal_name]
        if rows.empty:
            continue
        facts = rows['Mean number of supported facts']
        scores = rows['Mean FActScore']
        x_mean, y_mean = facts.mean(), scores.mean()
        ax.errorbar(x_mean, y_mean, xerr=facts.std(), yerr=scores.std(),
                    ecolor='#777', elinewidth=2, zorder=4, capsize=4, capthick=2)
        ax.scatter(x_mean, y_mean,
                   c=colors[display_name], s=500,
                   edgecolors='#fff',
                   marker=marker_map[display_name],
                   zorder=internal_name_to_zorder[internal_name])
        print(f"{model_name} - {internal_name}: {x_mean}, {y_mean}")


def _plot_fallback_points(ax, df_all):
    """Plot the hard-coded points (no error bars) when no CSV is available."""
    for method in df_all['method'].unique():
        # Same layering as internal_name_to_zorder: ConGrs > baselines > ASC.
        zorder_val = 6 if method == 'ConGrs' else (4 if method == 'ASC' else 5)
        for _, row in df_all[df_all['method'] == method].iterrows():
            ax.scatter(row['mean_num_facts'], row['mean_factscore'],
                       c=colors[method], s=500, edgecolors='#fff',
                       marker=marker_map[method], zorder=zorder_val)


def create_subplot_with_errorbars(ax, subplot_idx, model_name, congrs_data, asc_data, other_methods_data, csv_file):
    """Draw one model's FActScore vs. supported-facts trade-off panel on ``ax``.

    The hard-coded tables supply the ConGrs and ASC curves, the x-axis limit
    and the threshold labels. Markers with mean +/- std error bars come from
    ``csv_file``. If that file does not exist, the hard-coded points are
    plotted instead, without error bars.

    Args:
        ax: Matplotlib axes to draw on.
        subplot_idx: Panel position. Only panel 0 gets a y-label, and only
            panel 1 forces integer x ticks.
        model_name: Panel title, also used as a prefix in console output.
        congrs_data, asc_data, other_methods_data: Column dicts with keys
            'method', 'baseline', 'threshold', 'mean_factscore' and
            'mean_num_facts'.
        csv_file: Path to a CSV with 'Baselines', 'Mean number of supported
            facts' and 'Mean FActScore' columns (presumably one row per run;
            the std across rows becomes the error bar).
    """
    df_congrs = pd.DataFrame(congrs_data)
    df_asc = pd.DataFrame(asc_data)
    df_others = pd.DataFrame(other_methods_data)
    df_all = pd.concat([df_congrs, df_asc, df_others], ignore_index=True)

    # Trade-off curves, ordered by threshold; the ConGrs line sits above ASC.
    congrs_df = df_all[df_all['method'] == 'ConGrs'].sort_values('threshold')
    ax.plot(congrs_df['mean_num_facts'], congrs_df['mean_factscore'],
            color=colors['ConGrs'], linewidth=8, alpha=0.9,
            linestyle='-', zorder=4)

    asc_df = df_all[df_all['method'] == 'ASC'].sort_values('threshold')
    ax.plot(asc_df['mean_num_facts'], asc_df['mean_factscore'],
            color=colors['ASC'], linewidth=8, alpha=0.9,
            linestyle='-', zorder=3)

    set_xaxis_limits_from_max_value(ax, df_congrs, df_asc, df_others)
    _annotate_thresholds(ax, congrs_df, asc_df)

    # Only the file read is guarded: errors raised while plotting must not be
    # mistaken for a missing CSV.
    try:
        errorbar_data = pd.read_csv(csv_file)
    except FileNotFoundError:
        print(f"CSV file {csv_file} not found, plotting without error bars for {model_name}")
        _plot_fallback_points(ax, df_all)
    else:
        _plot_csv_points(ax, model_name, errorbar_data)

    ax.set_xlabel('Mean number of supported facts', fontsize=40, fontweight='bold')
    if subplot_idx == 0:  # Only the leftmost panel carries the y-axis label.
        ax.set_ylabel('Mean FActScore', fontsize=40, fontweight='bold', labelpad=10)

    ax.set_ylim(0.55, 1.0)

    if subplot_idx == 1:
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))

    ax.tick_params(axis='both', which='major', labelsize=30, length=16, width=2)

    ax.set_title(f'{model_name}', fontsize=42, fontweight='bold', pad=20)


# One entry per panel; these three lists are paired positionally with each
# other and with the subplot axes.
model_names = ['QWEN 2.5 72B', 'LLAMA 3.3 70B', 'OLMO 2 32B']
model_data = [
    (model1_congrs_data, model1_asc_data, model1_other_methods_data),
    (model2_congrs_data, model2_asc_data, model2_other_methods_data),
    (model3_congrs_data, model3_asc_data, model3_other_methods_data)
]


# Per-model CSVs providing the plotted markers and their error bars.
csv_files = [
    'graph-data/factscore-trade-off-qwen-bio-data.csv',
    'graph-data/factscore-trade-off-llama-bio-data.csv',
    'graph-data/factscore-trade-off-olmo-bio-data.csv',
]

# zip() below would silently truncate to the shortest sequence, so fail loudly
# if a model was added to one list but not the others (or to the figure).
if not (len(model_names) == len(model_data) == len(csv_files) == len(axes)):
    raise ValueError(
        f"Expected one model name, data tuple and CSV file per subplot ({len(axes)}); got "
        f"{len(model_names)} names, {len(model_data)} data tuples and {len(csv_files)} CSV files"
    )

for i, (ax, name, (congrs_data, asc_data, other_methods_data), csv_file) in enumerate(
        zip(axes, model_names, model_data, csv_files)):
    create_subplot_with_errorbars(ax, i, name, congrs_data, asc_data, other_methods_data, csv_file)

# Figure-level row label on the far left; subplots_adjust reserves the margin for it.
fig.text(0.02, 0.5, 'Biography factuality', fontsize=60, fontweight='bold',
         rotation=90, ha='center', va='center')

plt.tight_layout()
plt.subplots_adjust(left=0.08)
plt.savefig('bio_subplots.pdf', bbox_inches='tight')